package com.yeskery.nut.websocket;

import com.yeskery.nut.annotation.websocket.*;
import com.yeskery.nut.bind.BindObject;
import com.yeskery.nut.bind.FitValueHelper;
import com.yeskery.nut.util.ExceptionUtils;
import com.yeskery.nut.util.ReflectUtils;
import com.yeskery.nut.util.StringUtils;

import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Dynamic-proxy invocation handler that adapts an annotation-based WebSocket handler object
 * (methods annotated with {@link OnOpen}, {@link OnClose}, {@link OnMessage}, {@link OnError})
 * to the {@link WebSocketHandler} callback methods.
 *
 * <p>Calls to the callback methods named in {@code WEB_SOCKET_HANDLE_METHODS} are routed to the
 * matching annotated method of the target; any other proxied method is invoked reflectively on
 * the target itself. Resolved annotated methods are cached per callback (with separate entries
 * for the text and binary variants of {@code onMessage}).
 * @author sprout
 * @version 1.0
 * 2023-04-15 21:53
 */
public class AnnotationWebSocketHandlerInvocationHandler implements InvocationHandler {

    /** Logger */
    private static final Logger logger = Logger.getLogger(AnnotationWebSocketHandlerInvocationHandler.class.getName());

    /** Names of the WebSocket callback methods that are dispatched to annotated methods */
    private static final String[] WEB_SOCKET_HANDLE_METHODS = {WebSocketHandler.ON_OPEN_METHOD_NAME,
            WebSocketHandler.ON_CLOSE_METHOD_NAME, WebSocketHandler.ON_MESSAGE_METHOD_NAME,
            WebSocketHandler.ON_ERROR_METHOD_NAME};

    /** Cache-key suffix for the text-message variant of onMessage */
    private static final String TEXT_MESSAGE_POSTFIX = "$_text";

    /** Cache-key suffix for the binary-message variant of onMessage */
    private static final String BINARY_MESSAGE_POSTFIX = "$_binary";

    /** Sentinel cached when the target has no handler method for a callback */
    private static final Object EMPTY_METHOD = new Object();

    /** Target (annotated) handler object */
    private final Object target;

    /**
     * Cache from callback key to the resolved {@link Method}, or {@link #EMPTY_METHOD} when none exists.
     * A ConcurrentHashMap is used so the lazy {@code computeIfAbsent} population stays safe if the
     * proxy is invoked from several threads (e.g. several sessions sharing one handler).
     */
    private final Map<String, Object> handleMethodCacheMap = new ConcurrentHashMap<>(WEB_SOCKET_HANDLE_METHODS.length);

    /**
     * Creates the invocation handler for an annotation-based WebSocket handler.
     * @param target target handler object
     */
    public AnnotationWebSocketHandlerInvocationHandler(Object target) {
        this.target = target;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        try {
            String methodName = method.getName();
            if (Arrays.stream(WEB_SOCKET_HANDLE_METHODS).noneMatch(methodName::equals)) {
                // Not a WebSocket callback: delegate straight to the target object.
                return method.invoke(target, args);
            }

            String cacheKey = methodName;
            if (WebSocketHandler.ON_MESSAGE_METHOD_NAME.equals(methodName) && method.getParameterCount() > 0) {
                // onMessage is overloaded for text and binary payloads; resolve each variant separately.
                Class<?> messageType = findMessageParameterType(method.getParameterTypes());
                boolean textMessage = messageType != null && String.class.isAssignableFrom(messageType);
                cacheKey += textMessage ? TEXT_MESSAGE_POSTFIX : BINARY_MESSAGE_POSTFIX;
            }
            Object methodObject = handleMethodCacheMap.computeIfAbsent(cacheKey,
                    key -> findWebSocketHandlerMethod(key, target));
            if (methodObject != EMPTY_METHOD) {
                invokeWebSocketHandlerMethod((Method) methodObject, target, args);
            }
            return null;
        } catch (Throwable e) {
            logger.logp(Level.SEVERE, this.getClass().getName(), "invoke", "WebSocket Handler Method Invoke Fail.", e);
            throw e;
        }
    }

    /**
     * Finds the annotated handler method for a callback cache key.
     * @param cacheKey cache key (callback name, with a text/binary suffix for onMessage)
     * @param handler handler object
     * @return the handler {@link Method}, or {@link #EMPTY_METHOD} if none applies
     */
    private Object findWebSocketHandlerMethod(String cacheKey, Object handler) {
        switch (cacheKey) {
            case WebSocketHandler.ON_OPEN_METHOD_NAME:
                return doFindWebSocketHandlerMethod(handler.getClass(), cacheKey, OnOpen.class);
            case WebSocketHandler.ON_CLOSE_METHOD_NAME:
                return doFindWebSocketHandlerMethod(handler.getClass(), cacheKey, OnClose.class);
            case WebSocketHandler.ON_MESSAGE_METHOD_NAME + TEXT_MESSAGE_POSTFIX:
            case WebSocketHandler.ON_MESSAGE_METHOD_NAME + BINARY_MESSAGE_POSTFIX:
                return doFindWebSocketHandlerMethod(handler.getClass(), cacheKey, OnMessage.class);
            case WebSocketHandler.ON_ERROR_METHOD_NAME:
                return doFindWebSocketHandlerMethod(handler.getClass(), cacheKey, OnError.class);
            default:
                return EMPTY_METHOD;
        }
    }

    /**
     * Looks up the methods of {@code clazz} carrying {@code annotationClass} and picks the one to call.
     * @param clazz handler class
     * @param cacheKey cache key (decides text/binary selection for onMessage)
     * @param annotationClass callback annotation to look for
     * @return the handler {@link Method}, or {@link #EMPTY_METHOD} if none applies
     */
    private Object doFindWebSocketHandlerMethod(Class<?> clazz, String cacheKey, Class<? extends Annotation> annotationClass) {
        Method[] methods = ReflectUtils.getBeanAnnotationMethod(clazz, annotationClass);
        if (methods.length == 0) {
            return EMPTY_METHOD;
        }
        boolean textMessage = cacheKey.endsWith(TEXT_MESSAGE_POSTFIX);
        if (textMessage || cacheKey.endsWith(BINARY_MESSAGE_POSTFIX)) {
            Method method = doFindWebSocketHandlerOnMessageMethod(methods, textMessage);
            return method == null ? EMPTY_METHOD : method;
        }
        // For non-message callbacks the first annotated method is used.
        return methods[0];
    }

    /**
     * Selects the onMessage handler for a text or binary payload.
     * A method matches when its first payload parameter ({@code String} or {@code byte[]}, in any
     * position) is of the requested kind; if none matches, the last parameterless method is returned.
     * @param methods candidate methods annotated with {@link OnMessage}
     * @param textMessage {@code true} to look for a text handler, {@code false} for a binary handler
     * @return the matching method, or {@code null} if there is none
     */
    private Method doFindWebSocketHandlerOnMessageMethod(Method[] methods, boolean textMessage) {
        Method noArgumentMethod = null;
        for (Method method : methods) {
            if (method.getParameterCount() < 1) {
                // Fallback candidate: a handler that does not care about the payload.
                noArgumentMethod = method;
                continue;
            }
            Class<?> messageType = findMessageParameterType(method.getParameterTypes());
            if (messageType != null && String.class.isAssignableFrom(messageType) == textMessage) {
                return method;
            }
        }
        return noArgumentMethod;
    }

    /**
     * Returns the first parameter type that carries a message payload, i.e. one assignable to
     * {@code String} (text) or {@code byte[]} (binary).
     * @param parameterTypes method parameter types
     * @return the payload parameter type, or {@code null} if there is none
     */
    private static Class<?> findMessageParameterType(Class<?>[] parameterTypes) {
        for (Class<?> parameterType : parameterTypes) {
            if (String.class.isAssignableFrom(parameterType) || byte[].class.isAssignableFrom(parameterType)) {
                return parameterType;
            }
        }
        return null;
    }

    /**
     * Invokes a handler method, filling its parameters from the callback arguments.
     * Session-typed parameters receive the session, Throwable parameters receive the (unwrapped)
     * error, {@link PathParam} parameters are bound from the session path, and any other parameter
     * receives the first argument of a compatible type (or {@code null}).
     * @param method handler method
     * @param handler handler object
     * @param args callback arguments
     * @throws WebSocketException if no session is present among the arguments
     * @throws Throwable anything raised by reflective invocation
     */
    private void invokeWebSocketHandlerMethod(Method method, Object handler, Object[] args) throws Throwable {
        int parameterCount = method.getParameterCount();
        if (parameterCount == 0) {
            method.invoke(handler);
            return;
        }

        Session session = (Session) getParameter(args, Session.class);
        if (session == null) {
            throw new WebSocketException("WebSocket Session Obtain Fail.");
        }
        Class<?>[] parameterTypes = method.getParameterTypes();
        // Fetched once: Method#getParameters() returns a fresh copy on every call.
        Parameter[] methodParameters = method.getParameters();
        Object[] parameters = new Object[parameterCount];
        for (int i = 0; i < parameterCount; i++) {
            Class<?> parameterType = parameterTypes[i];
            if (parameterType.isAssignableFrom(Session.class)) {
                parameters[i] = session;
            } else if (Throwable.class.isAssignableFrom(parameterType)) {
                Object throwable = getParameter(args, Throwable.class);
                if (throwable != null) {
                    Throwable targetThrowable = ExceptionUtils.getTargetThrowable((Throwable) throwable);
                    // Only pass the error if it fits the declared parameter type; otherwise leave null.
                    if (parameterType.isInstance(targetThrowable)) {
                        parameters[i] = targetThrowable;
                    }
                }
            } else {
                Parameter parameter = methodParameters[i];
                PathParam pathParam = parameter.getAnnotation(PathParam.class);
                if (pathParam == null) {
                    parameters[i] = getParameter(args, parameterType);
                } else {
                    String pathName = StringUtils.isEmpty(pathParam.value()) ? parameter.getName() : pathParam.value();
                    String value = session.getPathParameter(pathName);
                    BindObject bindObject = FitValueHelper.getInstance().getFitParamValue(pathName, value, parameterType);
                    if (!bindObject.isEmpty()) {
                        parameters[i] = bindObject.getData();
                    }
                }
            }
        }
        method.invoke(handler, parameters);
    }

    /**
     * Returns the first argument that is an instance of the requested type.
     * Null arguments (and a null argument array) are skipped rather than dereferenced.
     * @param args argument array, may be {@code null}
     * @param clazz required parameter type
     * @return the matching argument, or {@code null} if none matches
     */
    private Object getParameter(Object[] args, Class<?> clazz) {
        if (args == null) {
            return null;
        }
        for (Object object : args) {
            if (clazz.isInstance(object)) {
                return object;
            }
        }
        return null;
    }
}
